{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "52d55e3f",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"90px\">\n",
    "\n",
    "# PySpark TensorFlow Inference\n",
    "\n",
    "## Image classification\n",
    "This notebook demonstrates training and distributed inference for image classification on MNIST.  \n",
    "Based on: https://www.tensorflow.org/tutorials/keras/save_and_load"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5233632d",
   "metadata": {},
   "source": [
    "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tensorflow==2.17.0`) and do not affect behavior, as noted in [this issue](https://github.com/tensorflow/tensorflow/issues/62075).  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c8b28f02",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-04 13:58:23.275397: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2025-02-04 13:58:23.282713: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2025-02-04 13:58:23.290717: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2025-02-04 13:58:23.293187: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2025-02-04 13:58:23.299616: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2025-02-04 13:58:23.677341: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import os\n",
    "import shutil\n",
    "import subprocess\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e2e67086",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.17.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1738706304.084788 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "I0000 00:00:1738706304.107153 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "I0000 00:00:1738706304.109954 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n"
     ]
    }
   ],
   "source": [
    "print(tf.version.VERSION)\n",
    "\n",
    "# Enable GPU memory growth so TensorFlow allocates GPU memory as needed\n",
    "# rather than reserving most of it when the GPU is first initialized.\n",
    "gpus = tf.config.list_physical_devices('GPU')\n",
    "try:\n",
    "    for gpu in gpus:\n",
    "        tf.config.experimental.set_memory_growth(gpu, True)\n",
    "except RuntimeError as e:\n",
    "    # Raised if memory growth is configured after the GPUs were already initialized.\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e0c7ad6",
   "metadata": {},
   "source": [
    "### Load and preprocess dataset\n",
    "\n",
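    "Load MNIST with its predefined train/test split, then keep the first 1000 examples of each split and flatten/normalize the images for the dense model."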
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5b007f7c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((60000, 28, 28), (10000, 28, 28))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MNIST ships with a predefined train/test split (60,000 / 10,000 images)\n",
    "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
    "train_images.shape, test_images.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7b7cedd1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1000, 784), (1000, 784))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep a small subset so training stays fast\n",
    "N_SAMPLES = 1000\n",
    "\n",
    "train_labels = train_labels[:N_SAMPLES]\n",
    "test_labels = test_labels[:N_SAMPLES]\n",
    "\n",
    "# Flatten each 28x28 image to a 784-vector and scale uint8 pixels to [0, 1].\n",
    "# The dtype check keeps this cell idempotent: re-running it must not rescale twice.\n",
    "if train_images.dtype == np.uint8:\n",
    "    train_images = train_images[:N_SAMPLES].reshape(-1, 28 * 28) / 255.0\n",
    "if test_images.dtype == np.uint8:\n",
    "    test_images = test_images[:N_SAMPLES].reshape(-1, 28 * 28) / 255.0\n",
    "\n",
    "train_images.shape, test_images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "867a4403",
   "metadata": {},
   "source": [
    "### Define a model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "746d94db",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
      "  super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n",
      "I0000 00:00:1738706304.278396 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0000 00:00:1738706304.281131 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "I0000 00:00:1738706304.283741 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "I0000 00:00:1738706304.403175 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "I0000 00:00:1738706304.404296 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "I0000 00:00:1738706304.405232 3671509 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
      "2025-02-04 13:58:24.406153: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 40769 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mModel: \"sequential\"\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Layer (type)                    </span>┃<span style=\"font-weight: bold\"> Output Shape           </span>┃<span style=\"font-weight: bold\">       Param # </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
       "│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>)            │       <span style=\"color: #00af00; text-decoration-color: #00af00\">401,920</span> │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dropout (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>)               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>)            │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>)             │         <span style=\"color: #00af00; text-decoration-color: #00af00\">5,130</span> │\n",
       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                   \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape          \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
       "│ dense (\u001b[38;5;33mDense\u001b[0m)                   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m)            │       \u001b[38;5;34m401,920\u001b[0m │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dropout (\u001b[38;5;33mDropout\u001b[0m)               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m)            │             \u001b[38;5;34m0\u001b[0m │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense_1 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m)             │         \u001b[38;5;34m5,130\u001b[0m │\n",
       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">407,050</span> (1.55 MB)\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">407,050</span> (1.55 MB)\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m407,050\u001b[0m (1.55 MB)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def create_model():\n",
    "    \"\"\"Build and compile a small MLP classifier for flattened 28x28 MNIST images.\n",
    "\n",
    "    Returns a compiled `keras.Sequential` model mapping 784 pixel values to\n",
    "    10 unnormalized class scores (logits).\n",
    "    \"\"\"\n",
    "    model = tf.keras.Sequential([\n",
    "        # Declare the input shape with an Input object, as Keras recommends,\n",
    "        # instead of passing `input_shape` to the first Dense layer.\n",
    "        keras.Input(shape=(784,)),\n",
    "        keras.layers.Dense(512, activation='relu'),\n",
    "        keras.layers.Dropout(0.2),\n",
    "        keras.layers.Dense(10),  # raw logits, hence from_logits=True in the loss\n",
    "    ])\n",
    "\n",
    "    model.compile(optimizer='adam',\n",
    "                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n",
    "\n",
    "    return model\n",
    "\n",
    "# Create a basic model instance\n",
    "model = create_model()\n",
    "\n",
    "# Display the model's architecture\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "605d082a",
   "metadata": {},
   "source": [
    "### Save checkpoints during training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dde1a855",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"models\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "244746be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1738706304.982690 3671754 service.cc:146] XLA service 0x7f1464019260 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "I0000 00:00:1738706304.982718 3671754 service.cc:154]   StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\n",
      "2025-02-04 13:58:24.999594: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
      "2025-02-04 13:58:25.043847: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m26s\u001b[0m 868ms/step - loss: 2.4638 - sparse_categorical_accuracy: 0.0625"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0000 00:00:1738706305.619913 3671754 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 17ms/step - loss: 1.6323 - sparse_categorical_accuracy: 0.4913  "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-04 13:58:26.791107: I external/local_xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:393] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_33', 4 bytes spill stores, 4 bytes spill loads\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1: val_sparse_categorical_accuracy improved from -inf to 0.76100, saving model to models/training_1/checkpoint.model.keras\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 48ms/step - loss: 1.6179 - sparse_categorical_accuracy: 0.4965 - val_loss: 0.7533 - val_sparse_categorical_accuracy: 0.7610\n",
      "Epoch 2/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.3965 - sparse_categorical_accuracy: 0.9062\n",
      "Epoch 2: val_sparse_categorical_accuracy improved from 0.76100 to 0.80400, saving model to models/training_1/checkpoint.model.keras\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.4549 - sparse_categorical_accuracy: 0.8773 - val_loss: 0.6002 - val_sparse_categorical_accuracy: 0.8040\n",
      "Epoch 3/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.4427 - sparse_categorical_accuracy: 0.8438\n",
      "Epoch 3: val_sparse_categorical_accuracy improved from 0.80400 to 0.85100, saving model to models/training_1/checkpoint.model.keras\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2924 - sparse_categorical_accuracy: 0.9289 - val_loss: 0.4876 - val_sparse_categorical_accuracy: 0.8510\n",
      "Epoch 4/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.3644 - sparse_categorical_accuracy: 0.9375\n",
      "Epoch 4: val_sparse_categorical_accuracy did not improve from 0.85100\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.2790 - sparse_categorical_accuracy: 0.9275 - val_loss: 0.4981 - val_sparse_categorical_accuracy: 0.8430\n",
      "Epoch 5/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.2368 - sparse_categorical_accuracy: 0.9375\n",
      "Epoch 5: val_sparse_categorical_accuracy did not improve from 0.85100\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1794 - sparse_categorical_accuracy: 0.9645 - val_loss: 0.4893 - val_sparse_categorical_accuracy: 0.8450\n",
      "Epoch 6/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.0830 - sparse_categorical_accuracy: 1.0000\n",
      "Epoch 6: val_sparse_categorical_accuracy improved from 0.85100 to 0.85400, saving model to models/training_1/checkpoint.model.keras\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1430 - sparse_categorical_accuracy: 0.9739 - val_loss: 0.4338 - val_sparse_categorical_accuracy: 0.8540\n",
      "Epoch 7/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.1518 - sparse_categorical_accuracy: 1.0000\n",
      "Epoch 7: val_sparse_categorical_accuracy improved from 0.85400 to 0.86200, saving model to models/training_1/checkpoint.model.keras\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0876 - sparse_categorical_accuracy: 0.9909 - val_loss: 0.4194 - val_sparse_categorical_accuracy: 0.8620\n",
      "Epoch 8/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 12ms/step - loss: 0.0209 - sparse_categorical_accuracy: 1.0000\n",
      "Epoch 8: val_sparse_categorical_accuracy improved from 0.86200 to 0.86800, saving model to models/training_1/checkpoint.model.keras\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0669 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.4038 - val_sparse_categorical_accuracy: 0.8680\n",
      "Epoch 9/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.0211 - sparse_categorical_accuracy: 1.0000\n",
      "Epoch 9: val_sparse_categorical_accuracy improved from 0.86800 to 0.86900, saving model to models/training_1/checkpoint.model.keras\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0429 - sparse_categorical_accuracy: 0.9998 - val_loss: 0.4062 - val_sparse_categorical_accuracy: 0.8690\n",
      "Epoch 10/10\n",
      "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 13ms/step - loss: 0.0283 - sparse_categorical_accuracy: 1.0000\n",
      "Epoch 10: val_sparse_categorical_accuracy did not improve from 0.86900\n",
      "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9992 - val_loss: 0.4069 - val_sparse_categorical_accuracy: 0.8680\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.src.callbacks.history.History at 0x7f1673724c50>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "checkpoint_path = \"models/training_1/checkpoint.model.keras\"\n",
    "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
    "\n",
    "# Create a callback that saves the full model (not just its weights) in .keras format,\n",
    "# overwriting the checkpoint only when validation accuracy improves.\n",
    "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n",
    "                                                 monitor='val_sparse_categorical_accuracy',\n",
    "                                                 mode='max',\n",
    "                                                 save_best_only=True,\n",
    "                                                 verbose=1)\n",
    "\n",
    "# Train the model with the new callback; the test split doubles as validation data\n",
    "# so the callback has a val_sparse_categorical_accuracy to monitor.\n",
    "model.fit(train_images,\n",
    "          train_labels,\n",
    "          epochs=10,\n",
    "          validation_data=(test_images, test_labels),\n",
    "          callbacks=[cp_callback])  # Pass callback to training\n",
    "\n",
    "# Loading this checkpoint into a fresh model later may warn that optimizer\n",
    "# variables were skipped. Only the model weights are needed for evaluation,\n",
    "# so that warning (and similar warnings in this notebook) can be ignored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "310eae08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['checkpoint.model.keras']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sorted(os.listdir(checkpoint_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50eeb6e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: models/mnist_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: models/mnist_model/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved artifact at 'models/mnist_model'. The following endpoints are available:\n",
      "\n",
      "* Endpoint 'serve'\n",
      "  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 784), dtype=tf.float32, name='keras_tensor')\n",
      "Output Type:\n",
      "  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)\n",
      "Captures:\n",
      "  139734758151120: TensorSpec(shape=(), dtype=tf.resource, name=None)\n",
      "  139734413261904: TensorSpec(shape=(), dtype=tf.resource, name=None)\n",
      "  139739081696528: TensorSpec(shape=(), dtype=tf.resource, name=None)\n",
      "  139734413262096: TensorSpec(shape=(), dtype=tf.resource, name=None)\n"
     ]
    }
   ],
   "source": [
    "# Export the trained model as a TensorFlow SavedModel directory (exposes a 'serve' endpoint)\n",
    "model.export(\"models/mnist_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6d3bba9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
      "  super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - 10ms/step - loss: 2.3876 - sparse_categorical_accuracy: 0.0840\n",
      "Untrained model, accuracy:  8.40%\n"
     ]
    }
   ],
   "source": [
    "# Build a fresh, untrained model to serve as a baseline\n",
    "model = create_model()\n",
    "\n",
    "# Evaluate it; with 10 classes, accuracy should be roughly chance level\n",
    "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
    "print(f\"Untrained model, accuracy: {100 * acc:5.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22ad1708",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - 704us/step - loss: 0.4062 - sparse_categorical_accuracy: 0.8690\n",
      "Restored model, accuracy: 86.90%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:713: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 10 variables. \n",
      "  saveable.load_own_variables(weights_store.get(inner_path))\n"
     ]
    }
   ],
   "source": [
    "# Restore the weights stored in the best training_1 checkpoint\n",
    "model.load_weights(checkpoint_path)\n",
    "\n",
    "# Re-evaluate; accuracy should match the best validation accuracy seen during training\n",
    "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
    "print(f\"Restored model, accuracy: {100 * acc:5.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c097d63",
   "metadata": {},
   "source": [
    "### Checkpoint callback options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cb336e89",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"models/training_2\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "750b6deb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 5: saving model to models/training_2/cp-0005.weights.h5\n",
      "\n",
      "Epoch 10: saving model to models/training_2/cp-0010.weights.h5\n",
      "\n",
      "Epoch 15: saving model to models/training_2/cp-0015.weights.h5\n",
      "\n",
      "Epoch 20: saving model to models/training_2/cp-0020.weights.h5\n",
      "\n",
      "Epoch 25: saving model to models/training_2/cp-0025.weights.h5\n",
      "\n",
      "Epoch 30: saving model to models/training_2/cp-0030.weights.h5\n",
      "\n",
      "Epoch 35: saving model to models/training_2/cp-0035.weights.h5\n",
      "\n",
      "Epoch 40: saving model to models/training_2/cp-0040.weights.h5\n",
      "\n",
      "Epoch 45: saving model to models/training_2/cp-0045.weights.h5\n",
      "\n",
      "Epoch 50: saving model to models/training_2/cp-0050.weights.h5\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.src.callbacks.history.History at 0x7f1672f47510>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Include the epoch in the file name (uses `str.format`)\n",
    "checkpoint_path = \"models/training_2/cp-{epoch:04d}.weights.h5\"\n",
    "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "# Number of batches per epoch, rounding up so a final partial batch is counted\n",
    "n_batches = math.ceil(len(train_images) / batch_size)\n",
    "\n",
    "# Create a callback that saves the model's weights every 5 epochs\n",
    "# (an integer save_freq counts batches, so 5 epochs = 5 * n_batches batches)\n",
    "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n",
    "    filepath=checkpoint_path,\n",
    "    verbose=1,\n",
    "    save_weights_only=True,\n",
    "    save_freq=5 * n_batches)\n",
    "\n",
    "# Create a new model instance\n",
    "model = create_model()\n",
    "\n",
    "# Save the initial (untrained) weights as epoch 0 using the `checkpoint_path` format\n",
    "model.save_weights(checkpoint_path.format(epoch=0))\n",
    "\n",
    "# Train the model with the new callback\n",
    "model.fit(train_images,\n",
    "          train_labels,\n",
    "          epochs=50,\n",
    "          batch_size=batch_size,\n",
    "          callbacks=[cp_callback],\n",
    "          validation_data=(test_images, test_labels),\n",
    "          verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1c43fd3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['cp-0000.weights.h5',\n",
       " 'cp-0015.weights.h5',\n",
       " 'cp-0010.weights.h5',\n",
       " 'cp-0035.weights.h5',\n",
       " 'cp-0020.weights.h5',\n",
       " 'cp-0040.weights.h5',\n",
       " 'cp-0050.weights.h5',\n",
       " 'cp-0005.weights.h5',\n",
       " 'cp-0045.weights.h5',\n",
       " 'cp-0025.weights.h5',\n",
       " 'cp-0030.weights.h5']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sorted(os.listdir(checkpoint_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0d7ae715",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint names are zero-padded by epoch, so lexicographic order is chronological\n",
    "checkpoints = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(\".weights.h5\"))\n",
    "latest = os.path.join(checkpoint_dir, checkpoints[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d345c6f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 - 0s - 11ms/step - loss: 0.4827 - sparse_categorical_accuracy: 0.8740\n",
      "Restored model, accuracy: 87.40%\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the model architecture, then restore the saved weights into it\n",
    "model = create_model()\n",
    "model.load_weights(latest)\n",
    "\n",
    "# Re-evaluate the restored model on the test set\n",
    "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
    "print(f\"Restored model, accuracy: {100 * acc:5.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a86f4700",
   "metadata": {},
   "source": [
    "## PySpark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7fcf07bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.functions import predict_batch_udf\n",
    "from pyspark.sql.functions import struct, col, array, pandas_udf\n",
    "from pyspark.sql.types import *\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark import SparkConf\n",
    "import pandas as pd\n",
    "import json"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50f02919",
   "metadata": {},
   "source": [
    "Check the cluster environment to handle any platform-specific Spark configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4c81d510",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detect the platform from its env var (the string value if set, otherwise False);\n",
    "# when neither is set we are on a local standalone cluster.\n",
    "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
    "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
    "on_standalone = not on_databricks and not on_dataproc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c58f4df7",
   "metadata": {},
   "source": [
    "#### Create Spark Session\n",
    "\n",
    "For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \n",
    "For cloud service provider (CSP) environments, Spark is either preconfigured (Databricks) or we create the Spark Session ourselves (Dataproc)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2c022c24",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:58:33 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
      "25/02/04 13:58:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "25/02/04 13:58:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    }
   ],
   "source": [
    "conf = SparkConf()\n",
    "\n",
    "# Only configure a new session when none exists yet (Databricks provides one preconfigured).\n",
    "if 'spark' not in globals():\n",
    "    if on_standalone:\n",
    "        import socket\n",
    "        import sys\n",
    "\n",
    "        # Point the driver and the workers at the same interpreter. Prefer the active conda\n",
    "        # env; if CONDA_PREFIX is unset, fall back to this kernel's interpreter instead of\n",
    "        # building the invalid path \"None/bin/python\".\n",
    "        conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
    "        python_path = f\"{conda_env}/bin/python\" if conda_env else sys.executable\n",
    "        hostname = socket.gethostname()\n",
    "        conf.setMaster(f\"spark://{hostname}:7077\")\n",
    "        conf.set(\"spark.pyspark.python\", python_path)\n",
    "        conf.set(\"spark.pyspark.driver.python\", python_path)\n",
    "    elif on_dataproc:\n",
    "        conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n",
    "\n",
    "    # 8 executor cores with 0.125 GPU per task => up to 8 concurrent tasks share each GPU.\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
    "    conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "    conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
    "    conf.set(\"spark.python.worker.reuse\", \"true\")\n",
    "\n",
    "# Set outside the guard so it is also applied to an existing session via getOrCreate().\n",
    "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
    "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
    "sc = spark.sparkContext"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c81d0b1b",
   "metadata": {},
   "source": [
    "### Create Spark Dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "49ff5203",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 784)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert the numpy test images to pandas: one row per image, one integer-labelled column per pixel\n",
    "test_pdf = pd.DataFrame(data=test_images)\n",
    "test_pdf.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "182ee0c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribute the rows across 8 partitions before writing\n",
    "df = spark.createDataFrame(test_pdf).repartition(numPartitions=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4e1c7ec-64fa-43c4-9bcf-0868a401d1f2",
   "metadata": {},
   "source": [
    "### Save as Parquet (784 float columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0061c39a-0871-429e-a4ff-751d26bf4b04",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:58:35 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n",
      "[Stage 0:>                                                          (0 + 8) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.05 ms, sys: 1.22 ms, total: 4.26 ms\n",
      "Wall time: 1.93 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "data_path_784 = \"spark-dl-datasets/mnist_784\"\n",
    "if on_databricks:\n",
    "    # On Databricks, store the dataset under the DBFS FileStore\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path_784 = f\"dbfs:/FileStore/{data_path_784}\"\n",
    "\n",
    "df.write.parquet(data_path_784, mode=\"overwrite\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18315afb-3fa2-4953-9297-52c04dd70c32",
   "metadata": {},
   "source": [
    "### Save as Parquet (1 column of 784 floats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "302c73ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000, 1)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pack each row's pixel columns into one list-valued 'data' column. Exclude any 'data'\n",
    "# column left by a previous run so re-executing this cell stays idempotent (otherwise\n",
    "# .values would also pick up the old list column).\n",
    "test_pdf['data'] = test_pdf.drop(columns='data', errors='ignore').values.tolist()\n",
    "pdf = test_pdf[['data']]\n",
    "pdf.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5495901b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single array-valued 'data' column, one list of pixels per row\n",
    "df = spark.createDataFrame(data=pdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5fa7faa8-c6bd-41b0-b5f7-fb121f0332e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 875 μs, sys: 187 μs, total: 1.06 ms\n",
      "Wall time: 196 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "data_path_1 = \"spark-dl-datasets/mnist_1\"\n",
    "if on_databricks:\n",
    "    # On Databricks, store the dataset under the DBFS FileStore\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path_1 = f\"dbfs:/FileStore/{data_path_1}\"\n",
    "\n",
    "df.write.parquet(data_path_1, mode=\"overwrite\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b366aaeb",
   "metadata": {},
   "source": [
    "## Inference using Spark DL API\n",
    "\n",
    "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
    "\n",
    "- `predict_batch_fn` uses TensorFlow APIs to load the model and return a predict function that operates on NumPy arrays.\n",
    "- `predict_batch_udf` converts the Spark DataFrame columns into NumPy input batches for the predict function."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4238fb28-d002-4b4d-9aa1-8af1fbd5d569",
   "metadata": {},
   "source": [
    "### 1 column of 784 floats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b9cf62f8-96b2-4716-80bd-bb93d5f939bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = os.path.join(os.getcwd(), \"models\", \"training_1\", \"checkpoint.model.keras\")\n",
    "\n",
    "# For cloud environments, copy the model to the distributed file system.\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
    "    dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/checkpoint.model.keras\"\n",
    "    shutil.copy(model_path, dbfs_model_path)\n",
    "    model_path = dbfs_model_path\n",
    "elif on_dataproc:\n",
    "    # GCS is mounted at /mnt/gcs by the init script\n",
    "    models_dir = \"/mnt/gcs/spark-dl/models\"\n",
    "    # makedirs also creates missing parents (os.mkdir fails if /mnt/gcs/spark-dl does not\n",
    "    # exist yet); exist_ok makes it a no-op when the directory is already there.\n",
    "    os.makedirs(models_dir, exist_ok=True)\n",
    "    gcs_model_path = os.path.join(models_dir, \"checkpoint.model.keras\")\n",
    "    shutil.copy(model_path, gcs_model_path)\n",
    "    model_path = gcs_model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b81fa297-d9d0-4600-880d-dbdcdf8bccc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_batch_fn():\n",
    "    \"\"\"Build the predict function used by predict_batch_udf.\n",
    "\n",
    "    Imports TensorFlow, enables GPU memory growth, and loads the Keras model\n",
    "    stored at the global `model_path`.\n",
    "\n",
    "    Returns:\n",
    "        A callable mapping a numpy batch of inputs to the model's output array.\n",
    "    \"\"\"\n",
    "    import tensorflow as tf\n",
    "\n",
    "    # Enable GPU memory growth to avoid CUDA OOM when several tasks share one GPU\n",
    "    gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "    if gpus:\n",
    "        try:\n",
    "            for gpu in gpus:\n",
    "                tf.config.experimental.set_memory_growth(gpu, True)\n",
    "        except RuntimeError as e:\n",
    "            # Raised if the GPUs were already initialized; report it and continue.\n",
    "            print(e)\n",
    "\n",
    "    model = tf.keras.models.load_model(model_path)\n",
    "\n",
    "    def predict(inputs: np.ndarray) -> np.ndarray:\n",
    "        # verbose=0 suppresses Keras' per-call progress bar in the executor logs.\n",
    "        return model.predict(inputs, verbose=0)\n",
    "\n",
    "    return predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "72a689bd-dd82-492e-8740-1738a215325f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap predict_batch_fn as a Spark UDF: 128-row batches, each row a 784-element vector\n",
    "mnist = predict_batch_udf(make_predict_fn=predict_batch_fn,\n",
    "                          input_tensor_shapes=[[784]],\n",
    "                          batch_size=128,\n",
    "                          return_type=ArrayType(FloatType()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "60a70150-26b1-4145-9e7d-6e17389216b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = spark.read.format(\"parquet\").load(data_path_1)\n",
    "len(df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "e027f0d2-0f65-47b7-a562-2f0965faceec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+\n",
      "|                data|\n",
      "+--------------------+\n",
      "|[0.0, 0.0, 0.0, 0...|\n",
      "|[0.0, 0.0, 0.0, 0...|\n",
      "|[0.0, 0.0, 0.0, 0...|\n",
      "|[0.0, 0.0, 0.0, 0...|\n",
      "|[0.0, 0.0, 0.0, 0...|\n",
      "+--------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.show(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "f0c3fb2e-469e-47bc-b948-8f6b0d7f6513",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 6:===================================================>       (7 + 1) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 24.1 ms, sys: 11 ms, total: 35.2 ms\n",
      "Wall time: 5.52 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# first pass caches model/fn\n",
    "# Call form 1: input columns wrapped in a single struct(...) column\n",
    "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "cdfa229a-f4a9-4c11-a410-de4a21c02c82",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 21.1 ms, sys: 14.7 ms, total: 35.8 ms\n",
      "Wall time: 277 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Call form 2: input columns passed as separate column-name arguments\n",
    "preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "5586ce49-6f93-4343-9b66-0dbb64972179",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 37.1 ms, sys: 8.46 ms, total: 45.6 ms\n",
      "Wall time: 216 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Call form 3: input columns passed as Column objects via col()\n",
    "preds = df.withColumn(\"preds\", mnist(*[col(c) for c in df.columns])).collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "004f1599-3c62-499e-9fd8-ed5cb0c90de4",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Check predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "4f947dc0-6b18-4605-810b-e83250a161db",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>preds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-4.6654954, -2.4895542, -0.5886033, 13.380537...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-2.273215, -7.5127845, 1.1983701, -3.540661, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-2.28909, 0.8308607, 0.31311005, 1.1683632, -...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-1.0551968, -6.5028114, 12.420729, 0.45280308...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-3.7887802, 3.9983602, -1.5343361, -0.3698440...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-4.499274, -1.7618222, 1.1183227, 3.946932, -...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-2.7540536, 4.8684144, 0.25152916, -0.4730078...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-1.8887109, 0.02717152, -6.0508857, 0.0875094...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[0.9541265, -2.113048, -1.7508972, -5.4303794,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-1.612412, -0.7655784, -4.473859, 2.0609212, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                data  \\\n",
       "0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "3  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "4  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "5  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "6  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "7  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "8  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "9  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "\n",
       "                                               preds  \n",
       "0  [-4.6654954, -2.4895542, -0.5886033, 13.380537...  \n",
       "1  [-2.273215, -7.5127845, 1.1983701, -3.540661, ...  \n",
       "2  [-2.28909, 0.8308607, 0.31311005, 1.1683632, -...  \n",
       "3  [-1.0551968, -6.5028114, 12.420729, 0.45280308...  \n",
       "4  [-3.7887802, 3.9983602, -1.5343361, -0.3698440...  \n",
       "5  [-4.499274, -1.7618222, 1.1183227, 3.946932, -...  \n",
       "6  [-2.7540536, 4.8684144, 0.25152916, -0.4730078...  \n",
       "7  [-1.8887109, 0.02717152, -6.0508857, 0.0875094...  \n",
       "8  [0.9541265, -2.113048, -1.7508972, -5.4303794,...  \n",
       "9  [-1.612412, -0.7655784, -4.473859, 2.0609212, ...  "
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score 10 rows and collect them to the driver as a pandas DataFrame\n",
    "preds = (\n",
    "    df.withColumn(\"preds\", mnist(*df.columns))\n",
    "    .limit(10)\n",
    "    .toPandas()\n",
    ")\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "de4964e0-d1f8-4753-afa1-a8f95ca3f151",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-4.6654954, -2.4895542, -0.5886033, 13.380537 , -6.652599 ,\n",
       "        2.8400383, -7.9901567, -0.7500452, -2.4487166, -4.349809 ],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Raw model scores for the first collected row\n",
    "sample = preds.iloc[0]\n",
    "sample[\"preds\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "44e9a874-e301-4b72-8df7-bf1c5133c287",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "c60e5af4-fc1e-4575-a717-f304664235be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted digit = index of the highest score; reshape the flat pixel vector to 28x28\n",
    "prediction = np.argmax(sample[\"preds\"])\n",
    "img = np.array(sample[\"data\"]).reshape((28, 28))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb45ecc9-d376-40c4-ad7b-2bd08ca5aaf6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIMhMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpK
SlLPnj0VFhb4u1RERIQ6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtDnUbAIBmLuRnQFFRUerXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQ
QAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbP
ZYwxtps4n9frlcfj0VCNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2
aMiQIaqqqtK//du/admyZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9///33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCC
AAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvva9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU165d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0
b9/eVFRU2G7tqnr88cdNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display the image with its predicted class in the title.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(\"Prediction: {}\".format(prediction))\n",
    "ax.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39167347-0b99-4972-998c-e1230bf1d4d5",
   "metadata": {},
   "source": [
    "### 784 columns of float"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "6bea332e-f6de-494f-a0db-795d9fe3e134",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_batch_fn():\n",
    "    \"\"\"Factory passed to predict_batch_udf: load the Keras model and return a batch predict function.\n",
    "\n",
    "    TensorFlow is imported inside the function so the import happens wherever the\n",
    "    factory is invoked (presumably on the Spark executors), not on the driver.\n",
    "    \"\"\"\n",
    "    import tensorflow as tf\n",
    "\n",
    "    # Allocate GPU memory on demand instead of reserving it all up front.\n",
    "    gpus = tf.config.list_physical_devices('GPU')\n",
    "    if gpus:\n",
    "        try:\n",
    "            for gpu in gpus:\n",
    "                tf.config.experimental.set_memory_growth(gpu, True)\n",
    "        except RuntimeError as e:\n",
    "            # Raised if the GPUs were already initialized in this process; report and continue.\n",
    "            print(e)\n",
    "\n",
    "    model = tf.keras.models.load_model(model_path)\n",
    "\n",
    "    def predict(inputs: np.ndarray) -> np.ndarray:\n",
    "        # verbose=0 keeps Keras from printing a progress bar for every UDF batch.\n",
    "        return model.predict(inputs, verbose=0)\n",
    "\n",
    "    return predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "731d234c-549f-4df3-8a2b-312e63195396",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One input tensor per row: a flattened 28x28 image of 784 values.\n",
    "mnist = predict_batch_udf(\n",
    "    predict_batch_fn,\n",
    "    input_tensor_shapes=[[784]],\n",
    "    return_type=ArrayType(FloatType()),\n",
    "    batch_size=128,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "a40fe207-6246-4b0e-abde-823979878d97",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "784"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = spark.read.parquet(data_path_784)\n",
    "num_columns = len(df.columns)\n",
    "num_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "10904f12-03e7-4518-8f12-2aa11989ddf5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 12:==============>                                           (2 + 6) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 52.5 ms, sys: 22 ms, total: 74.5 ms\n",
      "Wall time: 5.72 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "feature_struct = struct(*df.columns)\n",
    "preds = df.withColumn(\"preds\", mnist(feature_struct)).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "671128df-f0f4-4f54-b35c-d63a78c7f89a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 13:===========================================>              (6 + 2) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 49.4 ms, sys: 31.9 ms, total: 81.2 ms\n",
      "Wall time: 1.34 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "feature_array = array(*df.columns)\n",
    "preds = df.withColumn(\"preds\", mnist(feature_array)).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "ce35deaf-7d49-4f34-9bf9-b4e6fc5761f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passing the 784 columns as separate arguments should raise a ValueError --\n",
    "# presumably because it does not match the single input tensor declared via input_tensor_shapes=[[784]].\n",
    "# preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01709833-484b-451f-9aa8-37be5b7baf14",
   "metadata": {},
   "source": [
    "### Check prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "f9119632-b284-45d7-a262-c262e034c15c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>775</th>\n",
       "      <th>776</th>\n",
       "      <th>777</th>\n",
       "      <th>778</th>\n",
       "      <th>779</th>\n",
       "      <th>780</th>\n",
       "      <th>781</th>\n",
       "      <th>782</th>\n",
       "      <th>783</th>\n",
       "      <th>preds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[-6.9618006, 1.2047814, -0.09570807, 0.0462105...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[-5.2882323, 5.902014, -2.0389183, -1.2460864,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[-5.822013, -2.3333628, -2.4322102, -8.040086,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[-0.57203317, -1.2920653, -2.7234774, 0.914070...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[-3.689301, 5.0702505, -0.23930073, -0.7988689...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[8.268821, -2.070008, 1.722378, -1.8471404, -8...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[5.59269, -3.1613479, 0.4734843, -0.7772096, -...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[1.9852623, -5.166985, 0.86473066, -6.491789, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[-2.800528, -4.2984514, 10.887824, -3.1346364,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[-3.7827752, -4.51145, -5.354035, 9.399383, -6...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10 rows × 785 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     0    1    2    3    4    5    6    7    8    9  ...  775  776  777  778  \\\n",
       "0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "1  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "2  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "3  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "4  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "5  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "6  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "7  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "8  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "9  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   \n",
       "\n",
       "   779  780  781  782  783                                              preds  \n",
       "0  0.0  0.0  0.0  0.0  0.0  [-6.9618006, 1.2047814, -0.09570807, 0.0462105...  \n",
       "1  0.0  0.0  0.0  0.0  0.0  [-5.2882323, 5.902014, -2.0389183, -1.2460864,...  \n",
       "2  0.0  0.0  0.0  0.0  0.0  [-5.822013, -2.3333628, -2.4322102, -8.040086,...  \n",
       "3  0.0  0.0  0.0  0.0  0.0  [-0.57203317, -1.2920653, -2.7234774, 0.914070...  \n",
       "4  0.0  0.0  0.0  0.0  0.0  [-3.689301, 5.0702505, -0.23930073, -0.7988689...  \n",
       "5  0.0  0.0  0.0  0.0  0.0  [8.268821, -2.070008, 1.722378, -1.8471404, -8...  \n",
       "6  0.0  0.0  0.0  0.0  0.0  [5.59269, -3.1613479, 0.4734843, -0.7772096, -...  \n",
       "7  0.0  0.0  0.0  0.0  0.0  [1.9852623, -5.166985, 0.86473066, -6.491789, ...  \n",
       "8  0.0  0.0  0.0  0.0  0.0  [-2.800528, -4.2984514, 10.887824, -3.1346364,...  \n",
       "9  0.0  0.0  0.0  0.0  0.0  [-3.7827752, -4.51145, -5.354035, 9.399383, -6...  \n",
       "\n",
       "[10 rows x 785 columns]"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds = (\n",
    "    df.withColumn(\"preds\", mnist(struct(*df.columns)))\n",
    "    .limit(10)\n",
    "    .toPandas()\n",
    ")\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "7c067c62-03a6-461e-a1ff-4653276fbea1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "a7084ad0-c021-4296-bad0-7a238971f53b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-6.9618006 ,  1.2047814 , -0.09570807,  0.04621054, -5.8169513 ,\n",
       "       -4.148872  , -5.17938   ,  6.382909  , -0.11228667,  0.6022302 ],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample = preds.iloc[0]\n",
    "sample[\"preds\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "8167c832-93ef-4f50-873b-07b67c19ef53",
   "metadata": {},
   "outputs": [],
   "source": [
    "prediction = np.argmax(sample[\"preds\"])\n",
    "pixels = sample.drop(\"preds\").to_numpy(dtype=float)\n",
    "img = pixels.reshape(28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "297811e1-aecb-4afd-9a6a-30c49e8881cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiTklEQVR4nO3dfXBV9b3v8c/O0+YpCYQ8S8CAAhYET1FyuCCipAlBHVF6KmrvBY4FpQHFHGsPTgVRZtJDTzmoTcE59xTaUxAP0wK3lKKAJBQKdEAYBqu5kMYCAwnImAQChIf9u39w2cdNArg2O3zz8H7NrJnstX7ftb5ZLPiw9lp7bZ9zzgkAgFssyroBAED7RAABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEe3H777Zo0aVLwdWlpqXw+n0pLSyO2DZ/Pp9dffz1i6wNaKgIIrcbSpUvl8/mCU4cOHdS3b19Nnz5d1dXV1u15sm7dulYTMl/d51dP3/rWt6zbQysWY90A4NUbb7yh7OxsnTt3Tlu3btWiRYu0bt067d+/X506dbqlvYwcOVJnz55VXFycp7p169appKSkyRA6e/asYmJazl/N//zP/2w0b9euXXrrrbeUl5dn0BHaipZzlANfU0FBge69915J0ve+9z11795dCxYs0Jo1a/TUU081WVNfX6/OnTtHvJeoqCh16NAhouuM9Ppu1ne/+91G86689Xit/Q18HbwFh1bvoYcekiRVVlZKkiZNmqQuXbqooqJCY8eOVXx8vJ555hlJUiAQ0MKFCzVgwAB16NBBaWlpeu655/Tll1+GrNM5p3nz5qlHjx7q1KmTHnzwQX3yySeNtn2ta0A7d+7U2LFj1a1bN3Xu3FmDBg3SW2+9FeyvpKREUujbW1c0dQ1oz549KigoUEJCgrp06aLRo0drx44dIWOuvEW5bds2FRUVKSUlRZ07d9bjjz+uEydOhIytra3VZ599ptra2q+zi0M0NDToN7/5jR544AH16NHDcz1wBWdAaPUqKiokSd27dw/Ou3jxovLz8zVixAj967/+a/Ctueeee05Lly7V5MmT9cILL6iyslI/+9nPtGfPHm3btk2xsbGSpNmzZ2vevHkaO3asxo4dq48//lh5eXk6f/78DfvZsGGDHnnkEWVkZOjFF19Uenq6Pv30U61du1YvvviinnvuOR09elQbNmxo8u2tq33yySe6//77lZCQoFdeeUWxsbF69913NWrUKJWVlSknJydk/IwZM9StWzfNmTNHn3/+uRYuXKjp06fr/fffD45ZtWqVJk+erCVLloTcVPF1rFu3TjU1NcFQB8LmgFZiyZIlTpLbuHGjO3HihDt8+LBbsWKF6969u+vYsaM7cuSIc865iRMnOknun//5n0Pq//jHPzpJbtmyZSHz169fHzL/+PHjLi4uzj388MMuEAgEx7366qtOkps4cWJw3ubNm50kt3nzZueccxcvXnTZ2dmuV69e7ssvvwzZzlfXVVhY6K7110+SmzNnTvD1uHHjXFxcnKuoqAjOO3r0qIuPj3cjR45stH9yc3NDtvXSSy+56OhoV1NT02jskiVLmuzhesaPH+/8fn+j3w/wirfg0Ork5uYqJSVFWVlZmjBhgrp06aJVq1bptttuCxk3bdq0kNcrV65UYmKivvWtb+mLL74ITkOGDFGXLl20efNmSdLGjRt1/vx5zZgxI+StsZkzZ96wtz179qiyslIzZ85U165dQ5Z9dV1f16VLl/Thhx9q3Lhx6t27d3B+RkaGnn76aW3dulV1dXUhNVOnTg3Z1v33369Lly7pb3/7W3DepEmT5JzzfPZTV1en3//+9xo7dmyj3w/wirfg0OqUlJSob9++iomJUVpamvr166eoqND/S8XExDS6PnHgwAHV1tYqNTW1yfUeP35ckoL/UN95550hy1NSUtStW7fr9nbl7cCBAwd+
/V/oOk6cOKEzZ86oX79+jZbdddddCgQCOnz4sAYMGBCc37Nnz5BxV3q++jpXOH7zm9/o3LlzvP2GiCCA0OoMHTo0eBfctfj9/kahFAgElJqaqmXLljVZk5KSErEeLUVHRzc53zl30+tetmyZEhMT9cgjj9z0ugACCO1Gnz59tHHjRg0fPlwdO3a85rhevXpJunzG9NW3vU6cOHHDs4g+ffpIkvbv36/c3Nxrjvu6b8elpKSoU6dOKi8vb7Tss88+U1RUlLKysr7Wum7WsWPHtHnzZk2aNEl+v/+WbBNtG9eA0G585zvf0aVLl/Tmm282Wnbx4kXV1NRIunyNKTY2Vu+8807IWcPChQtvuI1vfvObys7O1sKFC4Pru+Kr67rymaSrx1wtOjpaeXl5WrNmjT7//PPg/Orqai1fvlwjRoxQQkLCDfu6Wji3Ya9YsUKBQIC33xAxnAGh3XjggQf03HPPqbi4WHv37lVeXp5iY2N14MABrVy5Um+99Za+/e1vKyUlRS+//LKKi4v1yCOPaOzYsdqzZ4/+8Ic/KDk5+brbiIqK0qJFi/Too4/qnnvu0eTJk5WRkaHPPvtMn3zyiT744ANJ0pAhQyRJL7zwgvLz8xUdHa0JEyY0uc558+Zpw4YNGjFihL7//e8rJiZG7777rhoaGjR//vyw9kU4t2EvW7ZMmZmZGjVqVFjbBK5GAKFdWbx4sYYMGaJ3331Xr776qmJiYnT77bfru9/9roYPHx4cN2/ePHXo0EGLFy/W5s2blZOTow8//FAPP/zwDbeRn5+vzZs3a+7cufrpT3+qQCCgPn36aMqUKcExTzzxhGbMmKEVK1bo17/+tZxz1wygAQMG6I9//KNmzZql4uJiBQIB5eTk6Ne//nWjzwA1l/Lycu3evVtFRUWNrq0B4fK5SFyZBADAI/4rAwAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMtLjPAQUCAR09elTx8fFhPT0YAGDLOadTp04pMzPzup8ba3EBdPTo0Vv2bCsAQPM5fPjwdb81t8UFUHx8vCRphMYqRrHG3QAAvLqoC9qqdcF/z6+l2QKopKREP/nJT1RVVaXBgwfrnXfe0dChQ29Yd+VttxjFKsZHAAFAq/P/n69zo8sozXITwvvvv6+ioiLNmTNHH3/8sQYPHqz8/PzgF34BANAsAbRgwQJNmTJFkydP1je+8Q0tXrxYnTp10i9+8Yvm2BwAoBWKeACdP39eu3fvDvkyrqioKOXm5mr79u2Nxjc0NKiuri5kAgC0fREPoC+++EKXLl1SWlpayPy0tDRVVVU1Gl9cXKzExMTgxB1wANA+mH8QddasWaqtrQ1Ohw8ftm4JAHALRPwuuOTkZEVHR6u6ujpkfnV1tdLT0xuN9/v9fL88ALRDET8DiouL05AhQ7Rp06bgvEAgoE2bNmnYsGGR3hwAoJVqls8BFRUVaeLEibr33ns1dOhQLVy4UPX19Zo8eXJzbA4A0Ao1SwA9+eSTOnHihGbPnq2qqirdc889Wr9+faMbEwAA7ZfPOeesm/iquro6JSYmapQe40kIANAKXXQXVKo1qq2tVUJCwjXHmd8FBwBonwggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGAi4gH0+uuvy+fzhUz9+/eP9GYAAK1cTHOsdMCAAdq4ceN/bySmWTYDAGjF
miUZYmJilJ6e3hyrBgC0Ec1yDejAgQPKzMxU79699cwzz+jQoUPXHNvQ0KC6urqQCQDQ9kU8gHJycrR06VKtX79eixYtUmVlpe6//36dOnWqyfHFxcVKTEwMTllZWZFuCQDQAvmcc645N1BTU6NevXppwYIFevbZZxstb2hoUENDQ/B1XV2dsrKyNEqPKcYX25ytAQCawUV3QaVao9raWiUkJFxzXLPfHdC1a1f17dtXBw8ebHK53++X3+9v7jYAAC1Ms38O6PTp06qoqFBGRkZzbwoA0IpEPIBefvlllZWV6fPPP9ef/vQnPf7444qOjtZTTz0V6U0BAFqxiL8Fd+TIET311FM6efKkUlJSNGLECO3YsUMpKSmR3hQAoBWLeACtWLEi0qtEOxc9oJ/nmpqB3cLa1qkJ3j8G8D9uq/Rcs+1Ib881w3v81XPN1lV/57lGknq+tddzTeDMmbC2hfaLZ8EBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAw0exfSAd8VfQd2Z5rpq7+veeahzvVeq6RpCj5PNcEFMaXCt+21XtNGKKmbwurrl9SoeeaPj/YHta20H5xBgQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMMHTsHFLueovPNcU/eEZzzUPj/+55xpJ+jJw1nPNfRtf8FwTdyTOc83+f/yZ55pw/fzx/+255q238z3XXDx8xHMN2g7OgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJjgYaS4pQKnTnmu6f/mXz3X3HPb//JcI0kd1yd4run779s918Rk9/Jco3/0XhKu1OjTnmtcpw7N0AnaMs6AAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmOBhpGjxLp044bmmx3jvNbdSQ6/unmui5GuGTq6xLZ+7ZdtC+8UZEADABAEEADDhOYC2bNmiRx99VJmZmfL5fFq9enXIcuecZs+erYyMDHXs2FG5ubk6cOBApPoFALQRngOovr5egwcPVklJSZPL58+fr7fffluLFy/Wzp071blzZ+Xn5+vcuXM33SwAoO3wfBNCQUGBCgoKmlzmnNPChQv1ox/9SI899pgk6Ve/+pXS0tK0evVqTZgw4ea6BQC0GRG9BlRZWamqqirl5uYG5yUmJionJ0fbtzf9tcUNDQ2qq6sLmQAAbV9EA6iqqkqSlJaWFjI/LS0tuOxqxcXFSkxMDE5ZWVmRbAkA0EKZ3wU3a9Ys1dbWBqfDhw9btwQAuAUiGkDp6emSpOrq6pD51dXVwWVX8/v9SkhICJkAAG1fRAMoOztb6enp2rRpU3BeXV2ddu7cqWHDhkVyUwCAVs7zXXCnT5/WwYMHg68rKyu1d+9eJSUlqWfPnpo5c6bmzZunO++8U9nZ2XrttdeUmZmpcePGRbJvAEAr5zmAdu3apQcffDD4uqioSJI0ceJELV26VK+88orq6+s1depU1dTUaMSIEVq/fr06dOgQua4BAK2e5wAaNWqUnLv2gwp9Pp/eeOMNvfHGGzfVGNCWHc71e64JyPsDQsN9gGlS1EXPNYEu3n8ntG/md8EBANonAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJz0/DBnDzfH1PW7dwXfOPP3jjQVdxuz9phk7QlnEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQPIwVu0ul/yPFc839yFoSxpQ5h1ITngz/c67nmdm1vhk7QlnEGBAAwQQABAEwQQAAAEwQQAMAEAQQA
MEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQPIwVu0tHcgOeaPjEdm6GTyMncdtG6BbQDnAEBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwwcNIga+I7p7kueahwZ96rgnIea4JR9/fPx9e3YcfR7gToDHOgAAAJgggAIAJzwG0ZcsWPfroo8rMzJTP59Pq1atDlk+aNEk+ny9kGjNmTKT6BQC0EZ4DqL6+XoMHD1ZJSck1x4wZM0bHjh0LTu+9995NNQkAaHs834RQUFCggoKC647x+/1KT08PuykAQNvXLNeASktLlZqaqn79+mnatGk6efLkNcc2NDSorq4uZAIAtH0RD6AxY8boV7/6lTZt2qR/+Zd/UVlZmQoKCnTp0qUmxxcXFysxMTE4ZWVlRbolAEALFPHPAU2YMCH48913361BgwapT58+Ki0t1ejRoxuNnzVrloqKioKv6+rqCCEAaAea/Tbs3r17Kzk5WQcPHmxyud/vV0JCQsgEAGj7mj2Ajhw5opMnTyojI6O5NwUAaEU8vwV3+vTpkLOZyspK7d27V0lJSUpKStLcuXM1fvx4paenq6KiQq+88oruuOMO5efnR7RxAEDr5jmAdu3apQcffDD4+sr1m4kTJ2rRokXat2+ffvnLX6qmpkaZmZnKy8vTm2++Kb/fH7muAQCtnucAGjVqlJy79oMUP/jgg5tqCLBUOaO/55o1We80QyeR8Y3Zh8Kquxho+q5VIJJ4FhwAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwETEv5IbaM1Gjt1j3cI13VX6Pc81fapa7u8DcAYEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABA8jBb7i57dtC6PK57ni/14457mm32tfeq656LkCuHU4AwIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCh5GiTTr9DzlhVn7suSIg57nmO3u+57km869/8VwDtGScAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBw0jR4kV3TfRc8z/nrm2GTiIn/adx1i0A5jgDAgCYIIAAACY8BVBxcbHuu+8+xcfHKzU1VePGjVN5eXnImHPnzqmwsFDdu3dXly5dNH78eFVXV0e0aQBA6+cpgMrKylRYWKgdO3Zow4YNunDhgvLy8lRfXx8c89JLL+l3v/udVq5cqbKyMh09elRPPPFExBsHALRunm5CWL9+fcjrpUuXKjU1Vbt379bIkSNVW1ur//iP/9Dy5cv10EMPSZKWLFmiu+66Szt27NDf//3fR65zAECrdlPXgGprayVJSUlJkqTdu3frwoULys3NDY7p37+/evbsqe3btze5joaGBtXV1YVMAIC2L+wACgQCmjlzpoYPH66BAwdKkqqqqhQXF6euXbuGjE1LS1NVVVWT6ykuLlZiYmJwysrKCrclAEArEnYAFRYWav/+/VqxYsVNNTBr1izV1tYGp8OHD9/U+gAArUNYH0SdPn261q5dqy1btqhHjx7B+enp6Tp//rxqampCzoKqq6uVnp7e5Lr8fr/8fn84bQAAWjFPZ0DOOU2fPl2rVq3SRx99pOzs7JDlQ4YMUWxsrDZt2hScV15erkOHDmnYsGGR6RgA0CZ4OgMqLCzU8uXLtWbNGsXHxwev6yQmJqpjx45KTEzUs88+q6KiIiUlJSkhIUEzZszQsGHDuAMOABDCUwAtWrRIkjRq1KiQ+UuWLNGkSZMkSf/2b/+mqKgojR8/Xg0NDcrPz9fPf/7ziDQLAGg7PAWQc+6GYzp06KCS
khKVlJSE3RTwVb5uXT3XPJt4KNythVkHwCueBQcAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMBHWN6ICLV1UmE+1jvaF8X8yFwhrW0B7xxkQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEzyMFC1e5TO3ea4JyIW3sTAeLJr36TjPNbE7/+K5JszfCGixOAMCAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABggoeRosVL3n/Rc83imt5hbevb8Z94rhmZctBzzZ8uxHmuAdoazoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCY4GGkaPE6rv6z55r1+/4urG0t+EG+55r4g97/GmXoT55rgLaGMyAAgAkCCABgwlMAFRcX67777lN8fLxSU1M1btw4lZeXh4wZNWqUfD5fyPT8889HtGkAQOvnKYDKyspUWFioHTt2aMOGDbpw4YLy8vJUX18fMm7KlCk6duxYcJo/f35EmwYAtH6erp6uX78+5PXSpUuVmpqq3bt3a+TIkcH5nTp1Unp6emQ6BAC0STd1Dai2tlaSlJSUFDJ/2bJlSk5O1sCBAzVr1iydOXPmmutoaGhQXV1dyAQAaPvCvg07EAho5syZGj58uAYOHBic//TTT6tXr17KzMzUvn379MMf/lDl5eX67W9/2+R6iouLNXfu3HDbAAC0UmEHUGFhofbv36+tW7eGzJ86dWrw57vvvlsZGRkaPXq0Kioq1KdPn0brmTVrloqKioKv6+rqlJWVFW5bAIBWIqwAmj59utauXastW7aoR48e1x2bk5MjSTp48GCTAeT3++X3+8NpAwDQinkKIOecZsyYoVWrVqm0tFTZ2dk3rNm7d68kKSMjI6wGAQBtk6cAKiws1PLly7VmzRrFx8erqqpKkpSYmKiOHTuqoqJCy5cv19ixY9W9e3ft27dPL730kkaOHKlBgwY1yy8AAGidPAXQokWLJF3+sOlXLVmyRJMmTVJcXJw2btyohQsXqr6+XllZWRo/frx+9KMfRaxhAEDb4PktuOvJyspSWVnZTTUEAGgfeBo22qSLf/08rLq+08KrA+AdDyMFAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgIsa6gas55yRJF3VBcsbNAAA8u6gLkv773/NraXEBdOrUKUnSVq0z7gQAcDNOnTqlxMTEay73uRtF1C0WCAR09OhRxcfHy+fzhSyrq6tTVlaWDh8+rISEBKMO7bEfLmM/XMZ+uIz9cFlL2A/OOZ06dUqZmZmKirr2lZ4WdwYUFRWlHj16XHdMQkJCuz7ArmA/XMZ+uIz9cBn74TLr/XC9M58ruAkBAGCCAAIAmGhVAeT3+zVnzhz5/X7rVkyxHy5jP1zGfriM/XBZa9oPLe4mBABA+9CqzoAAAG0HAQQAMEEAAQBMEEAAABMEEADARKsJoJKSEt1+++3q0KGDcnJy9Oc//9m6pVvu9ddfl8/nC5n69+9v3Vaz27Jlix599FFlZmbK5/Np9erVIcudc5o9e7YyMjLUsWNH5ebm6sCBAzbNNqMb7YdJkyY1Oj7GjBlj02wzKS4u1n333af4+HilpqZq3LhxKi8vDxlz7tw5FRYWqnv37urSpYvGjx+v6upqo46bx9fZD6NGjWp0PDz//PNGHTetVQTQ+++/r6KiIs2ZM0cff/yxBg8erPz8fB0/fty6tVtuwIABOnbsWHDaunWrdUvNrr6+
XoMHD1ZJSUmTy+fPn6+3335bixcv1s6dO9W5c2fl5+fr3Llzt7jT5nWj/SBJY8aMCTk+3nvvvVvYYfMrKytTYWGhduzYoQ0bNujChQvKy8tTfX19cMxLL72k3/3ud1q5cqXKysp09OhRPfHEE4ZdR97X2Q+SNGXKlJDjYf78+UYdX4NrBYYOHeoKCwuDry9duuQyMzNdcXGxYVe33pw5c9zgwYOt2zAlya1atSr4OhAIuPT0dPeTn/wkOK+mpsb5/X733nvvGXR4a1y9H5xzbuLEie6xxx4z6cfK8ePHnSRXVlbmnLv8Zx8bG+tWrlwZHPPpp586SW779u1WbTa7q/eDc8498MAD7sUXX7Rr6mto8WdA58+f1+7du5WbmxucFxUVpdzcXG3fvt2wMxsHDhxQZmamevfurWeeeUaHDh2ybslUZWWlqqqqQo6PxMRE5eTktMvjo7S0VKmpqerXr5+mTZumkydPWrfUrGprayVJSUlJkqTdu3frwoULIcdD//791bNnzzZ9PFy9H65YtmyZkpOTNXDgQM2aNUtnzpyxaO+aWtzTsK/2xRdf6NKlS0pLSwuZn5aWps8++8yoKxs5OTlaunSp+vXrp2PHjmnu3Lm6//77tX//fsXHx1u3Z6KqqkqSmjw+rixrL8aMGaMnnnhC2dnZqqio0KuvvqqCggJt375d0dHR1u1FXCAQ0MyZMzV8+HANHDhQ0uXjIS4uTl27dg0Z25aPh6b2gyQ9/fTT6tWrlzIzM7Vv3z798Ic/VHl5uX77298adhuqxQcQ/ltBQUHw50GDBiknJ0e9evXSf/3Xf+nZZ5817AwtwYQJE4I/33333Ro0aJD69Omj0tJSjR492rCz5lFYWKj9+/e3i+ug13Ot/TB16tTgz3fffbcyMjI0evRoVVRUqE+fPre6zSa1+LfgkpOTFR0d3egulurqaqWnpxt11TJ07dpVffv21cGDB61bMXPlGOD4aKx3795KTk5uk8fH9OnTtXbtWm3evDnk+8PS09N1/vx51dTUhIxvq8fDtfZDU3JyciSpRR0PLT6A4uLiNGTIEG3atCk4LxAIaNOmTRo2bJhhZ/ZOnz6tiooKZWRkWLdiJjs7W+np6SHHR11dnXbu3Nnuj48jR47o5MmTber4cM5p+vTpWrVqlT766CNlZ2eHLB8yZIhiY2NDjofy8nIdOnSoTR0PN9oPTdm7d68ktazjwfouiK9jxYoVzu/3u6VLl7q//OUvburUqa5r166uqqrKurVb6p/+6Z9caWmpq6ysdNu2bXO5ubkuOTnZHT9+3Lq1ZnXq1Cm3Z88et2fPHifJLViwwO3Zs8f97W9/c8459+Mf/9h17drVrVmzxu3bt8899thjLjs72509e9a488i63n44deqUe/nll9327dtdZWWl27hxo/vmN7/p7rzzTnfu3Dnr1iNm2rRpLjEx0ZWWlrpjx44FpzNnzgTHPP/8865nz57uo48+crt27XLDhg1zw4YNM+w68m60Hw4ePOjeeOMNt2vXLldZWenWrFnjevfu7UaOHGnceahWEUDOOffOO++4nj17uri4ODd06FC3Y8cO65ZuuSeffNJlZGS4uLg4d9ttt7knn3zSHTx40LqtZrd582YnqdE0ceJE59zlW7Ffe+01l5aW5vx+vxs9erQrLy+3bboZXG8/nDlzxuXl5bmUlBQXGxvrevXq5aZMmdLm/pPW1O8vyS1ZsiQ45uzZs+773/++69atm+vUqZN7/PHH3bFjx+yabgY32g+HDh1yI0eOdElJSc7v97s77rjD/eAHP3C1tbW2jV+F7wMCAJho8deAAABtEwEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBM/D/AaY3Zb7z6aAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display the image with its predicted class in the title.\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(\"Prediction: {}\".format(prediction))\n",
    "ax.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3dc87a7",
   "metadata": {},
   "source": [
    "## Using Triton Inference Server\n",
    "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for deep learning.  \n",
    "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \n",
    "\n",
    "The process looks like this:\n",
    "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
    "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
    "- Wrap the Triton inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.\n",
    "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
    "\n",
    "<img src=\"../images/spark-server.png\" alt=\"drawing\" width=\"700\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "cfc841c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1e63867",
   "metadata": {},
   "source": [
    "Import the helper class from `server_utils.py`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "d7af3599",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribute server_utils.py to the cluster so Spark tasks can import it too.\n",
    "sc.addPyFile(\"server_utils.py\")\n",
    "\n",
    "from server_utils import TritonServerManager"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32cbe1cb",
   "metadata": {},
   "source": [
    "Define the Triton Server function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "c3539d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_server(ports, model_path):\n",
    "    \"\"\"Serve the Keras model at `model_path` with PyTriton from inside a Spark task.\n",
    "\n",
    "    `ports` supplies (http_port, grpc_port, metrics_port). `triton.serve()` blocks;\n",
    "    the SIGTERM handler registered below stops the server.\n",
    "    \"\"\"\n",
    "    import time\n",
    "    import signal\n",
    "    import numpy as np\n",
    "    from pytriton.decorators import batch\n",
    "    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
    "    from pytriton.triton import Triton, TritonConfig\n",
    "    from pyspark import TaskContext\n",
    "    import tensorflow as tf\n",
    "    from tensorflow import keras\n",
    "\n",
    "    print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
    "\n",
    "    # Allocate GPU memory on demand instead of reserving it all up front.\n",
    "    gpus = tf.config.list_physical_devices('GPU')\n",
    "    if gpus:\n",
    "        try:\n",
    "            for gpu in gpus:\n",
    "                tf.config.experimental.set_memory_growth(gpu, True)\n",
    "        except RuntimeError as e:\n",
    "            # Raised if the GPUs were already initialized in this process; report and continue.\n",
    "            print(e)\n",
    "\n",
    "    model = keras.models.load_model(model_path)\n",
    "\n",
    "    @batch\n",
    "    def _infer_fn(**inputs):\n",
    "        images = inputs[\"images\"]\n",
    "        # Keep the batch axis: np.squeeze would turn a batch of one, (1, N), into (N,),\n",
    "        # which no longer matches the model's (batch, features) input.\n",
    "        images = images.reshape(images.shape[0], -1)\n",
    "        print(f\"SERVER: Received batch of size {len(images)}.\")\n",
    "        return {\n",
    "            \"labels\": model.predict(images, verbose=0)\n",
    "        }\n",
    "\n",
    "    # Include the hour so workspaces from different runs on the same day do not collide.\n",
    "    workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%H_%M_%S')}\"\n",
    "    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
    "    with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
    "        triton.bind(\n",
    "            model_name=\"ImageClassifier\",\n",
    "            infer_func=_infer_fn,\n",
    "            inputs=[\n",
    "                Tensor(name=\"images\", dtype=np.float64, shape=(-1,)),\n",
    "            ],\n",
    "            outputs=[\n",
    "                Tensor(name=\"labels\", dtype=np.float32, shape=(-1,)),\n",
    "            ],\n",
    "            config=ModelConfig(\n",
    "                max_batch_size=128,\n",
    "                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\n",
    "            ),\n",
    "            strict=True,\n",
    "        )\n",
    "\n",
    "        def _stop_triton(signum, frame):\n",
    "            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n",
    "            print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
    "            triton.stop()\n",
    "\n",
    "        signal.signal(signal.SIGTERM, _stop_triton)\n",
    "\n",
    "        print(\"SERVER: Serving inference\")\n",
    "        triton.serve()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce4c7701",
   "metadata": {},
   "source": [
    "#### Start Triton servers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2695d9ab",
   "metadata": {},
   "source": [
    "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n",
    "- Find available ports for HTTP/gRPC/metrics\n",
    "- Deploy a server on each node via stage-level scheduling\n",
    "- Gracefully shut down servers across nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4deae3b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Must match the model_name bound inside triton_server.\n",
    "model_name = \"ImageClassifier\"\n",
    "server_manager = TritonServerManager(\n",
    "    model_name=model_name,\n",
    "    model_path=model_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56c84f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Launch a Triton server on each node, running triton_server as the server function.\n",
    "# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}\n",
    "server_manager.start_servers(triton_server)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77847814",
   "metadata": {},
   "source": [
    "#### Define client function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e278fde0",
   "metadata": {},
   "source": [
    "Get the hostname-to-URL mapping from the server manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a9606e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use server_manager.host_to_grpc_url instead to connect over gRPC.\n",
    "host_to_http_url = server_manager.host_to_http_url"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d70bd6f",
   "metadata": {},
   "source": [
    "Define the Triton inference function, which returns a predict function for batch inference through the server:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "92ba2e26",
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_fn(model_name, host_to_url):\n",
    "    \"\"\"Build a batch-inference function backed by the Triton server on this host.\n",
    "\n",
    "    Args:\n",
    "        model_name: Name the model was bound under on the Triton server.\n",
    "        host_to_url: Mapping from hostname to the URL of that host's Triton server.\n",
    "\n",
    "    Returns:\n",
    "        A function that sends a batch of inputs to the server and returns its \"labels\" output.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If host_to_url has no entry for the current hostname.\n",
    "    \"\"\"\n",
    "    import socket\n",
    "    from pytriton.client import ModelClient\n",
    "\n",
    "    hostname = socket.gethostname()\n",
    "    if hostname not in host_to_url:\n",
    "        raise KeyError(f\"No Triton server URL for host '{hostname}'; known hosts: {list(host_to_url)}\")\n",
    "    url = host_to_url[hostname]\n",
    "    print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
    "\n",
    "    def infer_batch(inputs):\n",
    "        # A client is opened per batch; the context manager closes it after each request.\n",
    "        with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
    "            result_data = client.infer_batch(inputs)\n",
    "            return result_data[\"labels\"]\n",
    "\n",
    "    return infer_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "6658d2a1-ef7b-4ca1-9fb6-f2ac9050f3e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spark UDF that sends 784-element rows to the Triton server in batches of 128.\n",
    "predict = predict_batch_udf(\n",
    "    partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n",
    "    input_tensor_shapes=[[784]],\n",
    "    return_type=ArrayType(FloatType()),\n",
    "    batch_size=128,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3842c263",
   "metadata": {},
   "source": [
    "#### Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "43b93753-1d52-4060-9986-f24c30a67528",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "StructType([StructField('data', ArrayType(DoubleType(), True), True)])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the parquet input at data_path_1 and display its schema.\n",
    "df = spark.read.parquet(data_path_1)\n",
    "df.schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "8397aa14-82fd-4351-a477-dc8e8b321fa2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 19:>                                                         (0 + 8) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 19.8 ms, sys: 2.89 ms, total: 22.7 ms\n",
      "Wall time: 1.67 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Pass the input column wrapped in a struct.\n",
    "preds = df.withColumn(\"preds\", predict(struct(\"data\"))).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "82698bd9-377a-4415-8971-835487f876cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 19.8 ms, sys: 5.99 ms, total: 25.7 ms\n",
      "Wall time: 399 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Pass the input column by name.\n",
    "preds = df.withColumn(\"preds\", predict(\"data\")).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "419ad7bd-fa28-49d3-b98d-db9fba5aeaef",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 21:====================================>                     (5 + 3) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 9.07 ms, sys: 1.34 ms, total: 10.4 ms\n",
      "Wall time: 888 ms\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>preds</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-4.6654444, -2.4893682, -0.5888205, 13.380681...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-2.2732146, -7.5127845, 1.1983705, -3.540661,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-2.2890894, 0.8308606, 0.31311002, 1.1683631,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-1.055197, -6.502811, 12.420727, 0.4528031, -...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-3.7887795, 3.9983597, -1.5343359, -0.3698441...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-4.4992743, -1.7618219, 1.1183226, 3.9469318,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-2.754053, 4.868414, 0.2515293, -0.47300792, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-1.888711, 0.02717158, -6.050885, 0.08750934,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[0.9541264, -2.113048, -1.7508973, -5.4303784,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "      <td>[-1.612412, -0.7655782, -4.473859, 2.0609212, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                data  \\\n",
       "0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "3  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "4  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "5  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "6  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "7  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "8  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "9  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   \n",
       "\n",
       "                                               preds  \n",
       "0  [-4.6654444, -2.4893682, -0.5888205, 13.380681...  \n",
       "1  [-2.2732146, -7.5127845, 1.1983705, -3.540661,...  \n",
       "2  [-2.2890894, 0.8308606, 0.31311002, 1.1683631,...  \n",
       "3  [-1.055197, -6.502811, 12.420727, 0.4528031, -...  \n",
       "4  [-3.7887795, 3.9983597, -1.5343359, -0.3698441...  \n",
       "5  [-4.4992743, -1.7618219, 1.1183226, 3.9469318,...  \n",
       "6  [-2.754053, 4.868414, 0.2515293, -0.47300792, ...  \n",
       "7  [-1.888711, 0.02717158, -6.050885, 0.08750934,...  \n",
       "8  [0.9541264, -2.113048, -1.7508973, -5.4303784,...  \n",
       "9  [-1.612412, -0.7655782, -4.473859, 2.0609212, ...  "
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "# Pass the input as a Column object; score only 10 rows and collect them to pandas for display.\n",
    "preds = df.withColumn(\"preds\", predict(col(\"data\"))).limit(10).toPandas()\n",
    "preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "79d90a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "4ca495f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first scored row: the predicted class is the index of the largest output.\n",
    "sample = preds.iloc[0]\n",
    "prediction = np.argmax(sample.preds)\n",
    "\n",
    "# Restore the flattened input to a 28x28 image for plotting.\n",
    "img = np.array(sample.data).reshape(28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "a5d10903",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkfElEQVR4nO3dfXQUdZ7v8U/nqSEkaR7yLAFCFHRAcAYly/AgSiQEZUCYGUG9F7gziJiggI6KR0Udzsksrg7qIHjcHVhHEGWOyMoiDg9JGBRwwTCIM2QhJ0g4kIBcSYcAIaR/9w+uvbQkQDUdfkl4v86pc+iq37fqm6Lgk+qqrnYZY4wAALjKwmw3AAC4NhFAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAgAPdunXTpEmT/K8LCwvlcrlUWFgYsm24XC698MILIVsf0FwRQGgxlixZIpfL5Z/atGmjHj16KC8vT5WVlbbbc2TNmjUtJmTefvtt3X777UpKSpLb7VZ6eromT56s/fv3224NLVyE7QYAp1566SWlp6fr9OnT2rx5sxYuXKg1a9Zo9+7dio6Ovqq9DBkyRKdOnVJUVJSjujVr1mjBggUNhtCpU6cUEdF8/mkWFxcrPT1dP/vZz9ShQweVlZXp7bff1urVq/W3v/1NqamptltEC9V8jnLgMuXk5OjWW2+VJP36179Wp06d9Oqrr2rVqlWaMGFCgzU1NTVq165dyHsJCwtTmzZtQrrOUK/vSr355psXzBszZoxuvfVWvfPOO3r66actdIXWgLfg0OLdeeedkqSysjJJ0qRJkxQTE6PS0lKNHDlSsbGxeuCBByRJPp9P8+fPV69evdSmTRslJSVp6tSp+u677wLWaYzR3Llz1blzZ0VHR+uOO+7Q119/fcG2G7sGtG3bNo0cOVIdOnRQu3bt1KdPH7322mv+/hYsWCBJAW8pfq+ha0DFxcXKyclRXFycYmJiNGzYMG3dujVgzPdvUX722WeaNWuWEhIS1K5dO9177706evRowNiqqirt2bNHVVVVl7OLL9CtWzdJ0vHjx4OqByTOgNAKlJaWSpI6derkn3f27FllZ2dr0KBB+pd/+Rf/W3NTp07VkiVLNHnyZD366KMqKyvTH/7wBxUXF+uzzz5TZGSkJOn555/X3LlzNXLkSI0cOVJffvmlhg8frjNnzlyyn3Xr1umee+5RSkqKHnvsMSUnJ+sf//iHVq9erccee0xTp07VoUOHtG7dOv3pT3+65Pq+/vprDR48WHFxcXryyScVGRmpt956S0OHDlVRUZEyMzMDxk+fPl0dOnTQnDlztH//fs2fP195eXl6//33/WNWrlypyZMna/HixQE3VVzMsWPHVF9frwMHDuill16SJA0bNuyyaoEGGaCFWLx4sZFk1q9fb44ePWrKy8vN8uXLTadOnUzbtm3NwYMHjTHGTJw40UgyTz/9dED9X//6VyPJLF26NGD+2rVrA+YfOXLEREVFmbvvvtv4fD7/uGeeecZIMhMnTvTPKygoMJJMQUGBMcaYs2fPmvT0dNO1a1fz3XffBWzn/HXl5uaaxv75STJz5szxvx4zZoyJiooypaWl/nmHDh0ysbGxZsiQIRfsn6ysrIBtzZw504SHh5vjx49fMHbx4sUN9tAQt9ttJBlJplOnTub111+/7FqgIbwFhxYnKytLCQkJSktL0/jx4xUTE6OVK1fquuuuCxg3bdq0gNcrVqyQx+PRXXfdpW+//dY/9evXTzExMSooKJAkrV+/XmfOnNH06dMD3hqbMWPGJXsrLi5WWVmZZsyYofbt2wcsO39dl6u+vl5/+ctfNGbMGHXv3t0/PyUlRffff782b94sr9cbUPPQQw8FbGvw4MGqr6/XN9984583adIkGWMu++xHkj755BOtWbNGr7zyirp06aKamhrHPw9wPt6CQ4uzYMEC9ejRQxEREUpK
SlLPnj0VFhb4u1RERIQ6d+4cMG/v3r2qqqpSYmJig+s9cuSIJPn/o77hhhsClickJKhDhw4X7e37twN79+59+T/QRRw9elQnT55Uz549L1h20003yefzqby8XL169fLP79KlS8C473v+4XUup+644w5J524CGT16tHr37q2YmBjl5eVd0Xpx7SKA0OL079/ffxdcY9xu9wWh5PP5lJiYqKVLlzZYk5CQELIebQoPD29wvjEmZNvIyMjQj3/8Yy1dupQAQtAIIFwzMjIytH79eg0cOFBt27ZtdFzXrl0lnTtjOv9tr6NHj17yLCIjI0OStHv3bmVlZTU67nLfjktISFB0dLRKSkouWLZnzx6FhYUpLS3tstYVaqdOnVJtba2VbaN14BoQrhm//OUvVV9fr9/+9rcXLDt79qz/luKsrCxFRkbqjTfeCDhrmD9//iW38ZOf/ETp6emaP3/+Bbcon7+u7z+TdKnbmMPDwzV8+HCtWrUq4MkDlZWVWrZsmQYNGqS4uLhL9vVDl3sb9tmzZxsM3S+++EJfffXVJc9EgYvhDAjXjNtvv11Tp05Vfn6+du7cqeHDhysyMlJ79+7VihUr9Nprr+nnP/+5EhIS9MQTTyg/P1/33HOPRo4cqeLiYn3yySeKj4+/6DbCwsK0cOFCjRo1SrfccosmT56slJQU7dmzR19//bU+/fRTSVK/fv0kSY8++qiys7MVHh6u8ePHN7jOuXPnat26dRo0aJAeeeQRRURE6K233lJtba3mzZsX1L643NuwT5w4obS0NN13333q1auX2rVrp6+++kqLFy+Wx+PRc889F9T2AYkAwjVm0aJF6tevn9566y0988wzioiIULdu3fTggw9q4MCB/nFz585VmzZttGjRIhUUFCgzM1N/+ctfdPfdd19yG9nZ2SooKNCLL76oV155RT6fTxkZGZoyZYp/zNixYzV9+nQtX75c7777rowxjQZQr1699Ne//lWzZ89Wfn6+fD6fMjMz9e67717wGaBQi46O1q9//WsVFBToz3/+s06dOqXU1FRNmDBBzz77rP8DqUAwXCaUVyYBALhMXAMCAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMCKZvc5IJ/Pp0OHDik2NjaopwcDAOwyxqi6ulqpqakXPJPxfM0ugA4dOmTt2VYAgNApLy+/4Kn052t2ARQbGytJGqSRilCk5W4AAE6dVZ02a43///PGNFkALViwQC+//LIqKirUt29fvfHGG+rfv/8l675/2y1CkYpwEUAA0OL8/+frXOoySpPchPD+++9r1qxZmjNnjr788kv17dtX2dnZ/i/8AgCgSQLo1Vdf1ZQpUzR58mT96Ec/0qJFixQdHa0//vGPTbE5AEALFPIAOnPmjHbs2BHwZVxhYWHKysrSli1bLhhfW1srr9cbMAEAWr+QB9C3336r+vp6JSUlBcxPSkpSRUXFBePz8/Pl8Xj8E3fAAcC1wfoHUWfPnq2qqir/VF5ebrslAMBVEPK74OLj4xUeHq7KysqA+ZWVlUpOTr5gvNvtltvtDnUbAIBmLuRnQFFRUerXr582bNjgn+fz+bRhwwYNGDAg1JsDALRQTfI5oFmzZmnixIm69dZb1b9/f82fP181NTWaPHlyU2wOANACNUkA3XfffTp69Kief/55VVRU6JZbbtHatWsvuDEBAHDtchljjO0mzuf1euXxeDRUo3kSAgC0QGdNnQq1SlVVVYqLi2t0nPW74AAA1yYCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQ
QAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWBFhuwHYZ37aN7i6cOe/v0RWeh3XlP7vRMc1vu6nHNdI0p7b/+i4JtzlfD9MPTjAcU3hp7c4run2nzWOayRJW3cFVwc4wBkQAMAKAggAYEXIA+iFF16Qy+UKmG688cZQbwYA0MI1yTWgXr16af369f+zkQguNQEAAjVJMkRERCg5ObkpVg0AaCWa5BrQ3r17lZqaqu7du+uBBx7QgQMHGh1bW1srr9cbMAEAWr+QB1BmZqaWLFmitWvXauHChSorK9PgwYNVXV3d4Pj8/Hx5PB7/lJaWFuqWAADNUMgDKCcnR7/4xS/Up08fZWdna82aNTp+/Lg++OCDBsfPnj1bVVVV/qm8vDzULQEAmqEmvzugffv26tGjh/bt29fgcrfbLbfb3dRtAACamSb/HNCJEydUWlqqlJSUpt4UAKAFCXkAPfHEEyoqKtL+/fv1+eef695771V4eLgmTJgQ6k0BAFqwkL8Fd/DgQU2YMEHHjh1TQkKCBg0apK1btyohISHUmwIAtGAuY4yx3cT5vF6vPB6Phmq0IlyRttuxqubnmY5rKm91flK7dsLLjmskqUtEW8c1/2v/XY5r/tRtneManFN8xhdU3eOP5zmuif5wW1DbQutz1tSpUKtUVVWluLi4RsfxLDgAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIKHkV4lR/J+6rim8OlXHNdEu6Ic1zR339afclzTxhXc71Z1cv7PYcaBexzX/DLxvxzX3B1d5bgmWPvqah3XPDH4l45rzpYfdFyD5o+HkQIAmjUCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsiLDdwLXCF+68pjU+2frlYz9yXLNhxiDHNfVtg/vd6rvrnT+B/br/POy45s2EcY5r7v7zHx3XBGvsf011XNPt+P7QN4JWjTMgAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCh5FeJan/+jfHNR88kui4Jjv6gOOanDlPOK6RpLoYl+Oa6/7joOOaiP07nNc4rjgnOYia+iBqKu/5aRBVV8+uny5xXDMmiAes+qqrHdeg9eAMCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCs4GGkV4mvpsZxzTs90xzXvJ0z1nFNfEGx4xpJ8p0+7bjmbFBbunrCExIc13x3V4bjmscf/sBxDdDacAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcByY4yef/55paSkqG3btsrKytLevXtD1S8AoJVwHEA1NTXq27evFixY0ODyefPm6fXXX9eiRYu0bds2tWvXTtnZ2TodxPUCAEDr5fgmhJycHOXk5DS4zBij+fPn69lnn9Xo0aMlSe+8846SkpL00Ucfafz48VfWLQCg1QjpNaCysjJVVFQoKyvLP8/j8SgzM1NbtmxpsKa2tlZerzdgAgC0fiENoIqKCklSUlJSwPykpCT/sh/Kz8+Xx+PxT2lpzm89BgC0PNbvgps9e7aqqqr8U3l5ue2WAABXQUgDKDk5WZJUWVkZML+ystK/7Ifcbrfi4uICJgBA6xfSAEpPT1dycrI2bNjgn+f1erVt2zYNGDAglJsCALRwju+CO3HihPbt2+d/XVZWpp07d6pjx47q0qWLZsyYoblz5+qGG25Qenq6nnvuOaWmpmrMmDGh7BsA0MI5DqDt27frjjvu8L+eNWuWJGnixIlasmSJnnzySdXU1Oihhx7S8ePHNWjQIK1du1Zt2rQJXdcAgBbP
ZYwxtps4n9frlcfj0VCNVoQr0nY7aKHC23uCqnt8x2bHNUPanAlqW1eDT76g6n73bV/HNduGpTiuqf/2mOMaNH9nTZ0KtUpVVVUXva5v/S44AMC1iQACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACscfx0D0BKUPdorqLohbTaGuBO7VtXEB1X3ed+oIKp4sjWc4QwIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKzgYaRAK3ZX28NB1c2d9YDjmrqYoDblWGLxWcc1bT7+ogk6wZXiDAgAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArOBhpGiV0jacDKpuxyTnNf3cQW3qqogJC665HY+/EeJOQmfOkR87rtnxMb9rN0f8rQAArCCAAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFTyMFK2S67OdQdW9NOp+xzW1STGOa6of9zqu+eyW5Y5rWqNnE7Y7rrnzwUeD2pbn3a1B1eHycAYEALCCAAIAWOE4gDZt2qRRo0YpNTVVLpdLH330UcDySZMmyeVyBUwjRowIVb8AgFbCcQDV1NSob9++WrBgQaNjRowYocOHD/un995774qaBAC0Po5vQsjJyVFOTs5Fx7jdbiUnJwfdFACg9WuSa0CFhYVKTExUz549NW3aNB07dqzRsbW1tfJ6vQETAKD1C3kAjRgxQu+88442bNigf/7nf1ZRUZFycnJUX1/f4Pj8/Hx5PB7/lJaWFuqWAADNUMg/BzR+/Hj/n2+++Wb16dNHGRkZKiws1LBhwy4YP3v2bM2aNcv/2uv1EkIAcA1o8tuwu3fvrvj4eO3bt6/B5W63W3FxcQETAKD1a/IAOnjwoI4dO6aUlJSm3hQAoAVx/BbciRMnAs5mysrKtHPnTnXs2FEdO3bUiy++qHHjxik5OVmlpaV68skndf311ys7OzukjQMAWjbHAbR9+3bdcccd/tffX7+ZOHGiFi5cqF27dunf//3fdfz4caWmpmr48OH67W9/K7fbHbquAQAtnuMAGjp0qIwxjS7/9NNPr6ghwKb6r0sc10R87Xw7HQpcjmtGRf3Ucc3+P/VwXCNJn2QudFzTOaJtUNtyKtIV7rjmdMfgrjZ4gqrC5eJZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAi5F/JDeAyXOSJ8o2W1NY6run6y68c10jSnW/NdFzz3/csCmpbuHZxBgQAsIIAAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVvAwUqAVc0VGBVfXtj7EnYTOrjPOe0vcXtMEneBKcQYEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFbwMFKgFSt545ag6v572MLQNhJCMx6f7rgm+vNtTdAJrhRnQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQ8jRasUFhsbXF17T4g7adiRu9Ic19w1/TPHNf+RuMBxzTlX53fTD04kOq6J2/KN45qzjitwNXAGBACwggACAFjhKIDy8/N12223KTY2VomJiRozZoxKSkoCxpw+fVq5ubnq1KmTYmJiNG7cOFVWVoa0aQBAy+cogIqKipSbm6utW7dq3bp1qqur0/Dhw1VTU+MfM3PmTH388cdasWKFioqKdOjQIY0dOzbkjQMAWjZHNyGsXbs24PWSJUuUmJioHTt2
aMiQIaqqqtK//du/admyZbrzzjslSYsXL9ZNN92krVu36p/+6Z9C1zkAoEW7omtAVVVVkqSOHTtKknbs2KG6ujplZWX5x9x4443q0qWLtmzZ0uA6amtr5fV6AyYAQOsXdAD5fD7NmDFDAwcOVO/evSVJFRUVioqKUvv27QPGJiUlqaKiosH15Ofny+Px+Ke0NOe3pwIAWp6gAyg3N1e7d+/W8uXLr6iB2bNnq6qqyj+Vl5df0foAAC1DUB9EzcvL0+rVq7Vp0yZ17tzZPz85OVlnzpzR8ePHA86CKisrlZyc3OC63G633G53MG0AAFowR2dAxhjl5eVp5cqV2rhxo9LT0wOW9+vXT5GRkdqwYYN/XklJiQ4cOKABAwaEpmMAQKvg6AwoNzdXy5Yt06pVqxQbG+u/ruPxeNS2bVt5PB796le/0qxZs9SxY0fFxcVp+vTpGjBgAHfAAQACOAqghQsXSpKGDh0aMH/x4sWaNGmSJOn3v/+9wsLCNG7cONXW1io7O1tvvvlmSJoFALQeLmOMsd3E+bxerzwej4ZqtCJckbbbuSaE9b0pqLo9uTGOa5LT/q/jmiMlCY5rJt9Z6LhGkp7q9HVQdQhOn88nOa7p8ouvQt8IQuqsqVOhVqmqqkpxcXGNjuNZcAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALAiqG9ERfPl6tfLcU3b3x8Jalv/nfFuUHWO9bk6m2nuak2d45pIV3hQ26qsr3VcM+dQjuOazq8F1x9aB86AAABWEEAAACsIIACAFQQQAMAKAggAYAUBBACwggACAFhBAAEArCCAAABWEEAAACsIIACAFQQQAMAKHkbaytR0jXFc8373fw1ya1FB1jU9n3xB1c08NNhxzW8S1zuuyf4813FNbGG045rqbo5LJEnps7cEUVXtuCJMO4PYDloLzoAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoeRtrKRH+4zXHNLw4/HNS2jv64neMaXxDPL61z/nxVvf1//uC8SFLpbacd10z78VTHNek7dzmukTGOS+KdbwW4ajgDAgBYQQABAKwggAAAVhBAAAArCCAAgBUEEADACgIIAGAFAQQAsIIAAgBYQQABAKwggAAAVhBAAAArXMYE8YTDJuT1euXxeDRUoxXhirTdDgDAobOmToVapaqqKsXFxTU6jjMgAIAVBBAAwApHAZSfn6/bbrtNsbGxSkxM1JgxY1RSUhIwZujQoXK5XAHTww8H930zAIDWy1EAFRUVKTc3V1u3btW6detUV1en4cOHq6amJmDclClTdPjwYf80b968kDYNAGj5HH0j6tq1awNeL1myRImJidqxY4eGDBninx8dHa3k5OTQdAgAaJWu6BpQVVWVJKljx44B85cuXar4+Hj17t1bs2fP1smTJxtdR21trbxeb8AEAGj9HJ0Bnc/n82nGjBkaOHCgevfu7Z9///33q2vXrkpNTdWuXbv01FNPqaSkRB9++GGD68nPz9eLL74YbBsAgBYq6M8BTZs2TZ988ok2b96szp07Nzpu48aNGjZsmPbt26eMjIwLltfW1qq2ttb/2uv1Ki0tjc8BAUALdbmfAwrqDCgvL0+rV6/Wpk2bLho+kpSZmSlJjQaQ2+2W2+0Opg0AQAvmKICMMZo+fbpWrlypwsJCpaenX7Jm586dkqSUlJSgGgQAtE6OAig3N1fLli3TqlWrFBsbq4qKCkmSx+NR27ZtVVpaqmXLlmnkyJHq1KmTdu3apZkzZ2rIkCHq06dPk/wAAICWydE1IJfL1eD8xYsXa9KkSSovL9eDDz6o3bt3q6amRmlpabr33nv17LPPXvR9wPPxLDgAaNma5BrQpbIqLS1NRUVFTlYJALhG8Sw4AIAVBBAAwAoCCABgBQEEALCC
AAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAVEbYb+CFjjCTprOokY7kZAIBjZ1Un6X/+P29Mswug6upqSdJmrbHcCQDgSlRXV8vj8TS63GUuFVFXmc/n06FDhxQbGyuXyxWwzOv1Ki0tTeXl5YqLi7PUoX3sh3PYD+ewH85hP5zTHPaDMUbV1dVKTU1VWFjjV3qa3RlQWFiYOnfufNExcXFx1/QB9j32wznsh3PYD+ewH86xvR8udubzPW5CAABYQQABAKxoUQHkdrs1Z84cud1u261YxX44h/1wDvvhHPbDOS1pPzS7mxAAANeGFnUGBABoPQggAIAVBBAAwAoCCABgBQEEALCixQTQggUL1K1bN7Vp00aZmZn64osvbLd01b3wwgtyuVwB04033mi7rSa3adMmjRo1SqmpqXK5XProo48Clhtj9PzzzyslJUVt27ZVVlaW9u7da6fZJnSp/TBp0qQLjo8RI0bYabaJ5Ofn67bbblNsbKwSExM1ZswYlZSUBIw5ffq0cnNz1alTJ8XExGjcuHGqrKy01HHTuJz9MHTo0AuOh4cffthSxw1rEQH0/vvva9asWZozZ46+/PJL9e3bV9nZ2Tpy5Ijt1q66Xr166fDhw/5p8+bNtltqcjU1Nerbt68WLFjQ4PJ58+bp9ddf16JFi7Rt2za1a9dO2dnZOn369FXutGldaj9I0ogRIwKOj/fee+8qdtj0ioqKlJubq61bt2rdunWqq6vT8OHDVVNT4x8zc+ZMffzxx1qxYoWKiop06NAhjR071mLXoXc5+0GSpkyZEnA8zJs3z1LHjTAtQP/+/U1ubq7/dX19vUlNTTX5+fkWu7r65syZY/r27Wu7DaskmZUrV/pf+3w+k5ycbF5++WX/vOPHjxu3223ee+89Cx1eHT/cD8YYM3HiRDN69Ggr/dhy5MgRI8kUFRUZY8793UdGRpoVK1b4x/zjH/8wksyWLVtstdnkfrgfjDHm9ttvN4899pi9pi5Dsz8DOnPmjHbs2KGsrCz/vLCwMGVlZWnLli0WO7Nj7969Sk1NVffu3fXAAw/owIEDtluyqqysTBUVFQHHh8fjUWZm5jV5fBQWFioxMVE9e/bUtGnTdOzYMdstNamqqipJUseOHSVJO3bsUF1dXcDxcOONN6pLly6t+nj44X743tKlSxUfH6/evXtr9uzZOnnypI32GtXsnob9Q99++63q6+uVlJQUMD8pKUl79uyx1JUdmZmZWrJkiXr27KnDhw/rxRdf1ODBg7V7927Fxsbabs+KiooKSWrw+Ph+2bVixIgRGjt2rNLT01VaWqpnnnlGOTk52rJli8LDw223F3I+n08zZszQwIED1bt3b0nnjoeoqCi1b98+YGxrPh4a2g+SdP/996tr165KTU3Vrl279NRTT6mkpEQffvihxW4DNfsAwv/Iycnx/7lPnz7KzMxU165d9cEHH+hXv/qVxc7QHIwfP97/55tvvll9+vRRRkaGCgsLNWzYMIudNY3c3Fzt3r37mrgOejGN7YeHHnrI/+ebb75ZKSkpGjZsmEpLS5WRkXG122xQs38LLj4+XuHh4RfcxVJZWank5GRLXTUP7du3V48ePbRv3z7brVjz/THA8XGh7t27Kz4+vlUeH3l5eVq9erUKCgoCvj8sOTlZZ86c0fHjxwPGt9bjobH90JDMzExJalbHQ7MPoKioKPXr108bNmzwz/P5fNqwYYMGDBhgsTP7Tpw4odLSUqWkpNhuxZr09HQlJycHHB9er1fbtm275o+PgwcP6tixY63q+DDGKC8vTytXrtTGjRuVnp4esLxfv36KjIwMOB5KSkp04MCBVnU8XGo/NGTnzp2S1LyOB9t3QVyO5cuXG7fbbZYsWWL+/ve/m4ceesi0
b9/eVFRU2G7tqnr88cdNYWGhKSsrM5999pnJysoy8fHx5siRI7Zba1LV1dWmuLjYFBcXG0nm1VdfNcXFxeabb74xxhjzu9/9zrRv396sWrXK7Nq1y4wePdqkp6ebU6dOWe48tC62H6qrq80TTzxhtmzZYsrKysz69evNT37yE3PDDTeY06dP2249ZKZNm2Y8Ho8pLCw0hw8f9k8nT570j3n44YdNly5dzMaNG8327dvNgAEDzIABAyx2HXqX2g/79u0zL730ktm+fbspKyszq1atMt27dzdDhgyx3HmgFhFAxhjzxhtvmC5dupioqCjTv39/s3XrVtstXXX33XefSUlJMVFRUea6664z9913n9m3b5/ttppcQUGBkXTBNHHiRGPMuVuxn3vuOZOUlGTcbrcZNmyYKSkpsdt0E7jYfjh58qQZPny4SUhIMJGRkaZr165mypQpre6XtIZ+fklm8eLF/jGnTp0yjzzyiOnQoYOJjo429957rzl8+LC9ppvApfbDgQMHzJAhQ0zHjh2N2+02119/vfnNb35jqqqq7Db+A3wfEADAimZ/DQgA0DoRQAAAKwggAIAVBBAAwAoCCABgBQEEALCCAAIAWEEAAQCsIIAAAFYQQAAAKwggAIAV/w/hgVLrpVGHsAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.set_title(\"Prediction: {}\".format(prediction))\n",
    "ax.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6377f41a-5654-410b-8bad-d392e9dce7b8",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Stop the Triton servers on each executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06de00e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-04 14:00:18,330 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers.                 \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[True]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Signal each server to stop; triton_server's SIGTERM handler calls triton.stop().\n",
    "server_manager.stop_servers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "f612dc0b-538f-4ecf-81f7-ef6b58c493ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# spark.stop() puts a Databricks cluster in a bad state, so only stop the session elsewhere.\n",
    "if not on_databricks:\n",
    "    spark.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "490fc849-e47a-48d7-accc-429ff1cced6b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spark-dl-tf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
